package com.prawn.monitor.controller;

import com.prawn.monitor.arima.ARIMA;
import com.prawn.monitor.pojo.MeteorologicalData;
import com.prawn.monitor.pojo.WaterData;
import com.prawn.monitor.pojo.Warning;
import com.prawn.monitor.service.MeteorologicalDataService;
import com.prawn.monitor.service.WarningService;
import com.prawn.monitor.service.WaterDataService;
import entity.Result;
import entity.StatusCode;
import net.sf.json.JSONObject;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.client.RestTemplate;

import java.time.LocalDateTime;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Forecast endpoints for monitoring data.
 *
 * <p>Meteorological series are predicted locally (ARIMA or an LSTM network). Water-quality series
 * are forwarded to an external prediction service. Every endpoint also reports, under the
 * {@code warn} key, whether any checked predicted value lies outside the min/max of the first
 * {@link Warning} found for the request's equipment and channel.
 *
 * <p>Intended UI flow:
 * <ol>
 *   <li>First drop-down: every device of the base.</li>
 *   <li>Second drop-down: the device's check items. Meteorological and water devices have
 *       different items. The frontend shows Chinese names but must send the backend key (the
 *       channel mapping in {@link #checkItemNameChanage(String)} expects keys such as
 *       {@code air_temperature} or {@code dissolvedOxygen}).</li>
 *   <li>Third drop-down: the algorithm. Water data uses svm, LR, DT or RF; meteorological data
 *       uses arima or lstmrnn.</li>
 *   <li>Fourth and fifth drop-downs: start and end time, formatted like
 *       {@code 2019-10-08 04:47:47}.</li>
 * </ol>
 *
 * @author Administrator
 */
@RestController
@CrossOrigin
@RequestMapping("/datarecord/forecast")
public class ForecastController {

    /** Result code used for successful predictions (matches the values the endpoints always returned). */
    private static final int CODE_OK = 200;

    /** Result code for requests that cannot be served with the given input. */
    private static final int CODE_BAD_REQUEST = 400;

    /** Result code for an unusable answer from the external prediction service. */
    private static final int CODE_SERVER_ERROR = 500;

    /** Base URL of the external water-quality prediction service; the algorithm name is appended. */
    private static final String WATER_PREDICT_BASE_URL = "http://106.75.132.85:9010/";

    /**
     * Allowed {@code {algorithm}} path segments. This keeps URL-template, query and fragment
     * characters out of the URL sent to the prediction service.
     */
    private static final Pattern ALGORITHM_NAME = Pattern.compile("[A-Za-z0-9_-]+");

    /** The LSTM endpoint reserves the first 100 records, so it needs at least one more to train on. */
    private static final int LSTM_MIN_RECORDS = 101;

    /** Maximum number of most recent feature values sent to the water service as points to predict. */
    private static final int WATER_PREDICTION_COUNT = 10;

    /** Check-item key to the Chinese channel name used when looking up {@link Warning} thresholds. */
    private static final Map<String, String> CHANNEL_NAMES;

    static {
        Map<String, String> names = new HashMap<>();
        // Meteorological channels
        names.put("electric_energy", "电能");
        names.put("illumination", "光照");
        names.put("wind_speed", "风速");
        names.put("wind_direct", "风向");
        names.put("air_temperature", "气温");
        names.put("humidity", "湿度");
        names.put("rain", "雨量");
        names.put("soil_temperature", "土温");
        // Water-quality channels
        names.put("dissolvedOxygen", "溶解氧");
        names.put("waterTemperature", "水温");
        names.put("phValue", "pH");
        names.put("ammoniaNitrogen", "氨氮");
        names.put("conductivity", "电导率");
        names.put("turbidity", "浊度");
        // NOTE(review): "高猛" looks like a typo of "高锰" (permanganate). It is kept byte-for-byte
        // because it has to match the stored channel names. Confirm against the warning data.
        names.put("permanganateIndex", "高猛酸盐指数");
        names.put("phosphorus", "总磷");
        names.put("nitrogen", "总氮");
        names.put("chlorophyll", "叶绿素α");
        names.put("algalDensity", "藻密度");
        names.put("waterLevel", "水位");
        // HashMap (not Map.of) so that lookups with a null key return the default instead of throwing.
        CHANNEL_NAMES = Collections.unmodifiableMap(names);
    }

    @Autowired
    private MeteorologicalDataService meteorologicalDataService;

    @Autowired
    private WaterDataService waterDataService;

    @Autowired
    private WarningService warningService;

    /** Client for the external water-quality prediction service; RestTemplate is reusable across requests. */
    private final RestTemplate restTemplate = new RestTemplate();

    /**
     * Back-tests an ARIMA model on the selected meteorological series.
     *
     * <p>The queried records are read in reverse query order into a series. The last fifth of that
     * series is then predicted one point at a time. Each step fits ARIMA on every real point before
     * the predicted position, so the window grows by one real point per step.
     *
     * @param condiction query filter; {@code checkItemName} and {@code equipmentId} are read from it
     * @param page       page index passed to the data service
     * @param size       page size passed to the data service
     * @return result whose data map holds {@code orgindata} (the series followed by the
     *         predictions), {@code predictdata} (the predictions) and {@code warn}
     */
    @RequestMapping(value = "/arima/{page}/{size}", method = RequestMethod.POST)
    @ResponseBody
    public Result getPredictedDataByARIMA(@RequestBody Map<String, String> condiction, @PathVariable int page, @PathVariable int size) {
        // TODO: fetch only the needed column instead of whole entities.
        List<MeteorologicalData> dataRecordList = meteorologicalDataService.findSearch(condiction, page, size).getContent();
        String checkItemName = condiction.get("checkItemName");

        ArrayList<Double> arraylist = new ArrayList<>();
        for (int i = dataRecordList.size() - 1; i >= 0; i--) {
            arraylist.add(dataRecordList.get(i).getCheckItem(checkItemName));
        }

        ArrayList<Double> predictValueList = new ArrayList<>();
        int historySize = arraylist.size();
        int num = historySize / 5; // zero for fewer than 5 records: nothing is predicted
        for (int i = 0; i < num; i++) {
            // Measured against the original series size. The list also receives every prediction,
            // so its live size would advance the window by two points per step and feed earlier
            // predictions back into the training data.
            int windowSize = historySize - num + i;
            double[] dataArray = new double[windowSize];
            for (int j = 0; j < windowSize; j++) {
                dataArray[j] = arraylist.get(j);
            }

            ARIMA arima = new ARIMA(dataArray);
            int[] model = arima.getARIMAmodel();
            // Predict once, so the reported value and the appended value are the same number.
            double predicted = arima.aftDeal(arima.predictValue(model[0], model[1]));
            predictValueList.add(predicted);
            // Kept for response compatibility: orgindata has always ended with the predictions.
            arraylist.add(predicted);
        }

        boolean warn = exceedsWarningThreshold(condiction.get("equipmentId"), checkItemName, predictValueList);

        Map<String, Object> map = new HashMap<>();
        map.put("orgindata", arraylist);
        map.put("predictdata", predictValueList);
        map.put("warn", warn);
        return new Result(true, CODE_OK, "预测成功", map);
    }

    /**
     * Fits an LSTM network to the selected meteorological series and extrapolates it.
     *
     * <p>The network is one LSTM layer (5 units) followed by a linear RNN output layer. Records from
     * the last index down to index 100 are used for training. Each record is paired with the record
     * at the next lower index as its one-step-ahead label. Up to 99 records starting at index
     * {@code recordCount - 250} are then run through the trained network (fitted values, not a true
     * forecast). The last fitted value seeds a 50-step recursive forecast.
     *
     * @param condiction query filter; {@code checkItemName} and {@code equipmentId} are read from it
     * @param page       page index passed to the data service
     * @param size       page size passed to the data service
     * @return result whose data map holds {@code orgindata} (raw training inputs),
     *         {@code predictdata} (fitted values), {@code realPredictData} (50 forecast values, the
     *         first being the last fitted value) and {@code warn}. A failed result is returned
     *         when fewer than 101 records are available.
     */
    @RequestMapping(value = "/lstmrnn/{page}/{size}", method = RequestMethod.POST)
    @ResponseBody
    public Result getPredictedDataByLSTMRNN(@RequestBody Map<String, String> condiction, @PathVariable int page, @PathVariable int size) {
        List<MeteorologicalData> dataRecordList = meteorologicalDataService.findSearch(condiction, page, size).getContent();
        String checkItemName = condiction.get("checkItemName");
        int recordCount = dataRecordList.size();
        if (recordCount < LSTM_MIN_RECORDS) {
            // The training arrays below have length recordCount - 100 and must not be empty or negative.
            return new Result(false, CODE_BAD_REQUEST, "数据量不足，无法进行LSTM预测", null);
        }

        ArrayList<Double> predictValueList = new ArrayList<>();
        double[] inputTrainDataArr = new double[recordCount - 100];
        double[] outputTrainDataArr = new double[recordCount - 100];
        double[] inputTestDataArr = new double[99];
        double[] outputTestDataArr = new double[99];

        // Training set: input at index i, label at index i - 1.
        for (int i = recordCount - 1, j = 0; i >= 100 && j < inputTrainDataArr.length; i--, j++) {
            inputTrainDataArr[j] = dataRecordList.get(i).getCheckItem(checkItemName);
            outputTrainDataArr[j] = dataRecordList.get(i - 1).getCheckItem(checkItemName);
        }
        // Shape [miniBatch = 1, features = 1, timeSteps].
        DataSet trainData = new DataSet();
        trainData.setFeatures(Nd4j.create(new double[][][]{{inputTrainDataArr}}));
        trainData.setLabels(Nd4j.create(new double[][][]{{outputTrainDataArr}}));

        // Test set. With fewer than 349 records the trailing test inputs stay 0.0.
        // NOTE(review): each test label is two positions further along than its input (the training
        // pairs are one apart). The labels are only normalized below and never read.
        for (int j = 0, i = recordCount - 250; i > 0 && j < outputTestDataArr.length; i--, j++) {
            inputTestDataArr[j] = dataRecordList.get(i).getCheckItem(checkItemName);
        }
        for (int j = 0, i = recordCount - 251; i > 0 && j < outputTestDataArr.length; i--, j++) {
            outputTestDataArr[j] = dataRecordList.get(i - 1).getCheckItem(checkItemName);
        }
        DataSet testData = new DataSet();
        testData.setFeatures(Nd4j.create(new double[][][]{{inputTestDataArr}}));
        testData.setLabels(Nd4j.create(new double[][][]{{outputTestDataArr}}));

        // Scale features and labels into [0, 1] using statistics collected from the training set only.
        NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
        normalizer.fitLabel(true);
        normalizer.fit(trainData);
        normalizer.transform(trainData);
        normalizer.transform(testData);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(140)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.0015, 0.9))
                .list()
                .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(1).nOut(5)
                        .build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(org.nd4j.linalg.activations.Activation.IDENTITY).nIn(5).nOut(1).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        int nEpochs = 200;
        System.out.println("开始训练...");
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainData);
        }
        System.out.println("训练结束");

        // Prime the stored RNN state with the training sequence, then run the test inputs.
        net.rnnTimeStep(trainData.getFeatures());
        INDArray predicted = net.rnnTimeStep(testData.getFeatures());

        // Map the fitted values back to the original scale.
        normalizer.revertLabels(predicted);
        for (int i = 0; i < predicted.length(); i++) {
            predictValueList.add(predicted.getDouble(0, 0, i));
        }

        // Recursive forecast: each prediction is fed back in as the next single-step input.
        double[] realPredictData = new double[50];
        realPredictData[0] = predicted.getDouble(0, 0, predicted.length() - 1);
        for (int i = 1; i < realPredictData.length; i++) {
            INDArray indArray = Nd4j.create(new double[]{realPredictData[i - 1]}, new int[]{1, 1});
            normalizer.transform(indArray);
            INDArray realPredicted = net.rnnTimeStep(indArray);
            normalizer.revertLabels(realPredicted);
            realPredictData[i] = realPredicted.getDouble(0);
        }

        // NOTE(review): only the fitted values are checked against the thresholds, not realPredictData.
        boolean warn = exceedsWarningThreshold(condiction.get("equipmentId"), checkItemName, predictValueList);

        Map<String, Object> map = new HashMap<>();
        map.put("orgindata", inputTrainDataArr);
        map.put("predictdata", predictValueList);
        map.put("realPredictData", realPredictData);
        map.put("warn", warn);
        return new Result(true, CODE_OK, "预测成功", map);
    }

    /**
     * Forecasts a water-quality series through the external prediction service.
     *
     * <p>Record pairs {@code (i, i - 1)} where both values are present become training samples. The
     * feature is the value at {@code i} and the target is the value at {@code i - 1}. The last
     * feature values (at most 10, in reverse sample order) are sent as the points to predict.
     *
     * @param condiction query filter; {@code checkItemName} and {@code equipmentId} are read from it
     * @param algorithm  model name appended to the service URL; only letters, digits, '_' and '-'
     *                   are accepted
     * @param page       page index passed to the data service
     * @param size       page size passed to the data service
     * @return result whose data map holds {@code orgindata} (the targets), {@code predictdata}
     *         (the service's predictions) and {@code warn}. A failed result is returned for an
     *         invalid algorithm name or an unparseable service response.
     */
    @RequestMapping(value = "/water/{algorithm}/{page}/{size}", method = RequestMethod.POST)
    @ResponseBody
    public Result getPredictedDataBySVN(@RequestBody Map<String, String> condiction, @PathVariable String algorithm, @PathVariable int page, @PathVariable int size) {
        if (algorithm == null || !ALGORITHM_NAME.matcher(algorithm).matches()) {
            return new Result(false, CODE_BAD_REQUEST, "不支持的算法模型", null);
        }
        List<WaterData> dataRecordList = waterDataService.findSearch(condiction, page, size).getContent();
        String checkItemName = condiction.get("checkItemName");

        ArrayList<ArrayList<Double>> features = new ArrayList<>();
        ArrayList<Double> targets = new ArrayList<>();
        for (int i = dataRecordList.size() - 1; i >= 1; i--) {
            Double current = dataRecordList.get(i).getCheckItem(checkItemName);
            Double next = dataRecordList.get(i - 1).getCheckItem(checkItemName);
            if (current == null || next == null) {
                continue;
            }
            ArrayList<Double> row = new ArrayList<>();
            row.add(current);
            features.add(row);
            targets.add(next);
        }

        // The size of this list determines how many predictions the service returns.
        ArrayList<ArrayList<Double>> toPredict = new ArrayList<>();
        int oldest = Math.max(0, features.size() - WATER_PREDICTION_COUNT);
        for (int i = features.size() - 1; i >= oldest; i--) {
            ArrayList<Double> row = new ArrayList<>();
            row.add(features.get(i).get(0));
            toPredict.add(row);
        }

        // The service expects each field as the string form of a (nested) list, e.g. "[[1.0], [2.0]]".
        JSONObject postData = new JSONObject();
        postData.put("feature", String.valueOf(features));
        postData.put("target", String.valueOf(targets));
        postData.put("prediction", String.valueOf(toPredict));

        String response = restTemplate.postForObject(WATER_PREDICT_BASE_URL + algorithm, postData, String.class);
        Optional<List<Double>> parsed = parsePredictionArray(response);
        if (!parsed.isPresent()) {
            return new Result(false, CODE_SERVER_ERROR, "预测服务返回数据异常", null);
        }
        List<Double> prediction = parsed.get();

        Map<String, Object> map = new HashMap<>();
        map.put("orgindata", targets);
        map.put("predictdata", prediction);
        map.put("warn", exceedsWarningThreshold(condiction.get("equipmentId"), checkItemName, prediction));
        return new Result(true, CODE_OK, "预测成功", map);
    }

    /**
     * Extracts the numbers between the first '[' and the last ']' of a prediction service response.
     * For example, {@code {"prediction": [2.0, 4.0], "message": "success"}} gives {@code [2.0, 4.0]}.
     * String parsing is used because JSONObject-based parsing was unreliable for this response.
     *
     * @param response raw response body; may be {@code null}
     * @return the parsed values (empty for {@code []}), or empty when the body is {@code null} or
     *         has no bracketed array
     * @throws NumberFormatException if an element inside the brackets is not a number
     */
    private static Optional<List<Double>> parsePredictionArray(String response) {
        if (response == null) {
            return Optional.empty();
        }
        int open = response.indexOf('[');
        int close = response.lastIndexOf(']');
        if (open < 0 || close < open) {
            return Optional.empty();
        }
        String body = response.substring(open + 1, close).trim();
        List<Double> values = new ArrayList<>();
        if (!body.isEmpty()) {
            for (String token : body.split(",")) {
                values.add(Double.parseDouble(token.trim()));
            }
        }
        return Optional.of(values);
    }

    /**
     * Checks predicted values against the first warning configuration for an equipment/channel.
     *
     * @param equipmentId   equipment id from the request
     * @param checkItemName check-item key, converted with {@link #checkItemNameChanage(String)}
     * @param values        predicted values to check
     * @return {@code true} if a warning configuration exists and some value is above its max or
     *         below its min value
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private boolean exceedsWarningThreshold(String equipmentId, String checkItemName, List<Double> values) {
        // Raw map, exactly as the service was always called with, so that it compiles against its signature.
        Map whereMap = new HashMap();
        whereMap.put("equipmentId", equipmentId);
        whereMap.put("channelName", checkItemNameChanage(checkItemName));
        List<Warning> search = warningService.findSearch(whereMap);
        if (search.isEmpty()) {
            return false;
        }
        Warning warning = search.get(0);
        double maxV = warning.getMaxValue();
        double minV = warning.getMinValue();
        for (double value : values) {
            if (value > maxV || value < minV) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts a check-item key (meteorological or water-quality) into its channel name.
     * The channel name is used when looking up warning thresholds.
     *
     * @param checkItemName check-item key such as {@code air_temperature} or {@code phValue};
     *                      may be {@code null}
     * @return the Chinese channel name, or {@code ""} when the key is unknown or {@code null}
     */
    public static String checkItemNameChanage(String checkItemName) {
        return CHANNEL_NAMES.getOrDefault(checkItemName, "");
    }
}